{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e9025a4a-b8ef-4901-b98e-753b756b028a",
   "metadata": {},
   "source": [
    "# Building a RAG chat without the LangChain framework\n",
    "## To understand in more detail what's going on\n",
    "\n",
    "The technical know-how comes from Ed Donner, obviously, as well as from Sakalya Mitra & Pradip Nichite on [this gem of a blog post](https://blog.futuresmart.ai/building-rag-applications-without-langchain-or-llamaindex) I found on futuresmart.ai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7acfb5-8bf9-48b5-a219-46f1e3bfafc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv\n",
    "import gradio as gr\n",
    "import re\n",
    "from openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19af6b8b-be29-4086-a69f-5e2cdb867ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports for Chroma and plotly\n",
    "\n",
    "import chromadb\n",
    "from chromadb.utils import embedding_functions\n",
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6d9ab4-816a-498c-a04c-c3838770d848",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chat model used for answering, and the on-disk location of the Chroma store.\n",
    "MODEL = \"gpt-4o-mini\"\n",
    "db_name = \"chroma_db\"\n",
    "\n",
    "# Reuse db_name so the storage path is defined in exactly one place.\n",
    "client = chromadb.PersistentClient(path=db_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3715b81-eed0-4412-8c01-0623ed113657",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load variables from a local .env file into the environment, then hand the key\n",
    "# that was read to the OpenAI client explicitly.\n",
    "# NOTE(review): with api_key=None the client is expected to fall back to the\n",
    "# OPENAI_API_KEY environment variable, same as a bare OpenAI() - confirm.\n",
    "load_dotenv()\n",
    "openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
    "openai = OpenAI(api_key=openai_api_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3017e1dd-d0d5-4ef4-8c72-84517a927793",
   "metadata": {},
   "source": [
    "### Making stuff at home: documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e83480a5-927b-4756-a978-520a56ceed85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Items in `documents` are objects shaped like Document(metadata={...}, page_content=\"...\"),\n",
    "# so a small `Document` class is defined here.\n",
    "\n",
    "class Document:\n",
    "    \"\"\"A piece of text (`page_content`) together with a `metadata` dict describing it.\"\"\"\n",
    "\n",
    "    def __init__(self, metadata, page_content):\n",
    "        self.metadata = metadata\n",
    "        self.page_content = page_content\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"Document(metadata={self.metadata}, page_content={repr(self.page_content)})\"\n",
    "\n",
    "\n",
    "documents = []\n",
    "\n",
    "def get_documents(path='.'):\n",
    "    \"\"\"Recursively load every non-hidden file below `path` into the global `documents` list.\n",
    "\n",
    "    Each file is read as UTF-8 text and stored as a Document whose metadata holds:\n",
    "      - \"source\":   the path of the file\n",
    "      - \"doc_type\": the name of the folder that directly contains the file\n",
    "      - \"self\":     the file name\n",
    "\n",
    "    Entries whose name starts with a dot are skipped, files and folders alike.\n",
    "    Returns nothing; results are appended to `documents`.\n",
    "    \"\"\"\n",
    "    # Sorted so the load order is the same on every run (os.listdir order is arbitrary).\n",
    "    for entry in sorted(os.listdir(path)):\n",
    "        if entry.startswith('.'):\n",
    "            continue\n",
    "\n",
    "        full_path = os.path.join(path, entry)\n",
    "        if os.path.isdir(full_path):\n",
    "            get_documents(full_path)\n",
    "            continue\n",
    "\n",
    "        # os.path splits on the platform's own separator, so no hand-written\n",
    "        # backslash regex is needed to find the containing folder.\n",
    "        doc_type = os.path.basename(os.path.dirname(full_path))\n",
    "        file_name = os.path.basename(full_path)\n",
    "\n",
    "        with open(full_path, mode=\"r\", encoding=\"utf-8\") as f:\n",
    "            content = f.read()\n",
    "\n",
    "        metadata = {\"source\": full_path, \"doc_type\": doc_type, \"self\": file_name}\n",
    "        documents.append(Document(metadata=metadata, page_content=content))\n",
    "\n",
    "# where the knowledge collection lives (built with os.path.join so it is not Windows-only)\n",
    "directory_path = os.path.join('.', 'knowledge_collection')\n",
    "get_documents(directory_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd846bc0-54d0-4802-a18b-196c396a241c",
   "metadata": {},
   "source": [
    "### Making stuff at home: chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "202b33e2-c3fe-424c-9c8e-a90e517add42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chunking settings.\n",
    "# Target chunk length in characters.\n",
    "chunk_size = 1000\n",
    "\n",
    "# A sentence boundary is either whitespace that follows one of . ! ? ;\n",
    "# or a run of line-break characters.\n",
    "eos_pattern = re.compile(r\"((?<=[.!?;])[\\s]+)|([\\n\\r]+)\")\n",
    "\n",
    "# Receives the chunk objects produced by the chunking loop.\n",
    "chunks = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a19a61ec-d204-4b87-9f05-88832d03fad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_document(doc):\n",
    "    \"\"\"Split one Document into sentence-aligned chunks of roughly `chunk_size` characters.\n",
    "\n",
    "    Candidate cut points are the starts of `eos_pattern` matches located beyond\n",
    "    index `chunk_size - 50`. A cut is made at every candidate that lies at least\n",
    "    `chunk_size - 50` characters after the previous cut. Whatever is left after\n",
    "    the last cut becomes a final chunk if it is longer than 5 characters, so a\n",
    "    text without any usable boundary comes back as a single chunk.\n",
    "\n",
    "    Every chunk gets its own copy of the document metadata plus an 'id' entry\n",
    "    of the form \"<source>_chunk_\".\n",
    "\n",
    "    Returns a list of Document objects (possibly empty).\n",
    "    \"\"\"\n",
    "    text = doc.page_content\n",
    "    min_length = chunk_size - 50\n",
    "    boundaries = [m.start() for m in eos_pattern.finditer(text) if m.start() > min_length]\n",
    "\n",
    "    pieces = []\n",
    "    start = 0\n",
    "    for point in boundaries:\n",
    "        if point - start >= min_length:\n",
    "            pieces.append(text[start:point])\n",
    "            start = point\n",
    "\n",
    "    # Keep the remainder as the last chunk unless it is only a few stray characters.\n",
    "    if len(text) - start > 5:\n",
    "        pieces.append(text[start:])\n",
    "\n",
    "    doc_chunks = []\n",
    "    for piece in pieces:\n",
    "        # Copy the metadata: sharing one dict would write 'id' into the source\n",
    "        # document and make all chunks of a document alias the same object.\n",
    "        metadata = dict(doc.metadata)\n",
    "        metadata['id'] = f\"{doc.metadata['source']}_chunk_\"\n",
    "        doc_chunks.append(Document(metadata=metadata, page_content=piece))\n",
    "    return doc_chunks\n",
    "\n",
    "\n",
    "# Rebuild the list from scratch so re-running this cell does not append duplicates.\n",
    "chunks = []\n",
    "for doc in documents:\n",
    "    chunks.extend(split_document(doc))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "966ae50c-e0e5-403a-9465-8f26967f8922",
   "metadata": {},
   "source": [
    "### Making stuff without a framework: embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b97391c0-e55f-4e08-b0cb-5e62fb119ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure sentence transformer embeddings\n",
    "embeddings = embedding_functions.SentenceTransformerEmbeddingFunction(\n",
    "    model_name=\"all-MiniLM-L6-v2\"\n",
    ")\n",
    "\n",
    "collection_name = \"documents_collection\"\n",
    "\n",
    "# Start from an empty collection on every run. Existence is checked up front\n",
    "# rather than by catching ValueError from delete_collection.\n",
    "# NOTE(review): the exception raised for a missing collection appears to differ\n",
    "# between Chroma releases, and list_collections() may return Collection objects\n",
    "# or plain names depending on the version (hence the getattr fallback) - confirm\n",
    "# against the installed version.\n",
    "existing_names = {getattr(c, \"name\", c) for c in client.list_collections()}\n",
    "\n",
    "if collection_name in existing_names:\n",
    "    client.delete_collection(collection_name)\n",
    "else:\n",
    "    print(f\"{collection_name} doesn't exist yet\")\n",
    "\n",
    "# Create collection\n",
    "collection = client.get_or_create_collection(\n",
    "    name=collection_name,\n",
    "    embedding_function=embeddings\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5222dfec-8cf4-4e87-aeb8-33d0f3b3b5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# adding our chunks to the \"collection\"\n",
    "# Done in batches: one add() call per batch instead of one per chunk, and the\n",
    "# position comes from the loop itself rather than a chunks.index() scan per chunk.\n",
    "\n",
    "ADD_BATCH_SIZE = 100\n",
    "\n",
    "for batch_start in range(0, len(chunks), ADD_BATCH_SIZE):\n",
    "    batch = chunks[batch_start:batch_start + ADD_BATCH_SIZE]\n",
    "    collection.add(\n",
    "        documents=[chunk.page_content for chunk in batch],\n",
    "        metadatas=[chunk.metadata for chunk in batch],\n",
    "        # id = the chunk's 'id' prefix followed by its position in `chunks`\n",
    "        ids=[\n",
    "            f\"{chunk.metadata['id']}{batch_start + offset}\"\n",
    "            for offset, chunk in enumerate(batch)\n",
    "        ]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5effcada-ee5f-4207-9fa6-1fc5604b068b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def semantic_search(collection, query: str, n_results: int = 4):\n",
    "    \"\"\"Query `collection` with a single text and return the raw result.\n",
    "\n",
    "    `query` is wrapped in a one-element list for `query_texts`, and\n",
    "    `n_results` is forwarded as given. Whatever `collection.query` returns\n",
    "    is passed through unchanged.\n",
    "    \"\"\"\n",
    "    return collection.query(query_texts=[query], n_results=n_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99f0a366-3dcb-4824-9f33-70e07af984d8",
   "metadata": {},
   "source": [
    "## Visualizing the Vector Store\n",
    "\n",
    "The results actually look just as good with `all-MiniLM-L6-v2`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e12751ab-f102-4dc6-9c0f-313e5832b75f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prework\n",
    "# NOTE: `documents` is rebound here from Document objects to plain strings,\n",
    "# so the chunking cell cannot be re-run after this one without reloading the files.\n",
    "\n",
    "result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
    "vectors = np.array(result['embeddings'])\n",
    "documents = result['documents']\n",
    "doc_types = [metadata['doc_type'] for metadata in result['metadatas']]\n",
    "\n",
    "# One color per known doc_type; anything else falls back to gray, so an\n",
    "# unexpected folder name does not abort the cell with a ValueError.\n",
    "type_colors = {'languages': 'blue', 'mountains': 'red', 'regions': 'orange'}\n",
    "colors = [type_colors.get(t, 'gray') for t in doc_types]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "422e3247-2de0-44ba-82bc-30b4f739da7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
    "# (t-distributed stochastic neighbor embedding)\n",
    "\n",
    "# t-SNE needs perplexity < number of samples; cap it so a small collection\n",
    "# still plots (30 is scikit-learn's default, kept for larger collections).\n",
    "perplexity = min(30, max(1, len(vectors) - 1))\n",
    "tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)\n",
    "reduced_vectors = tsne.fit_transform(vectors)\n",
    "\n",
    "# Create the 2D scatter plot; hovering a point shows its type and the first 100 characters\n",
    "fig = go.Figure(data=[go.Scatter(\n",
    "    x=reduced_vectors[:, 0],\n",
    "    y=reduced_vectors[:, 1],\n",
    "    mode='markers',\n",
    "    marker=dict(size=5, color=colors, opacity=0.8),\n",
    "    text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
    "    hoverinfo='text'\n",
    ")])\n",
    "\n",
    "fig.update_layout(\n",
    "    title='2D Chroma Vector Store Visualization',\n",
    "    # `scene` configures 3D plots only; for a 2D scatter the axis titles\n",
    "    # are set on the layout itself.\n",
    "    xaxis_title='x',\n",
    "    yaxis_title='y',\n",
    "    width=800,\n",
    "    height=600,\n",
    "    margin=dict(r=20, b=10, l=10, t=40)\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cff9065-de3d-4e91-8aff-c7ad750a4334",
   "metadata": {},
   "source": [
    "#### Comment: Relying on Gradio's history handling seems to be memory enough\n",
    "##### If all you need is your favorite LLM with expertise in your knowledge collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aebb676f-883e-4b2b-8420-13f2a8399e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The pieces below are joined by implicit string concatenation; each one ends with\n",
    "# the space that separates it from the next piece.\n",
    "system_prompt = (\n",
    "    \"You are a helpful assistant for everything French. Give brief, accurate answers. \"\n",
    "    \"Do not provide any information that you haven't been asked for, even if you have lots of context. \"\n",
    "    \"If you haven't been provided with relevant context, say you don't know. Do not make anything up, only \"\n",
    "    \"provide answers that are based in the context you have been given. Do not comment on the provided context. \"\n",
    "    \"If the user doesn't ask for any information, engage in brief niceties and offer your expertise regarding France.\"\n",
    ")\n",
    "\n",
    "# Module-level seed history holding only the system message.\n",
    "history = [{\"role\": \"system\", \"content\": system_prompt}]\n",
    "\n",
    "def get_user_prompt(prompt):\n",
    "    \"\"\"Build the user message for the LLM, enriched with retrieved context.\n",
    "\n",
    "    Runs a semantic search for `prompt` against `collection`. If anything comes\n",
    "    back, a marker line and then each retrieved text are appended to the prompt,\n",
    "    every part preceded by a blank line.\n",
    "\n",
    "    Returns a dict with \"role\": \"user\" and the assembled text as \"content\".\n",
    "    \"\"\"\n",
    "    retrieved = semantic_search(collection, prompt)['documents'][0]\n",
    "\n",
    "    if len(retrieved) > 0:\n",
    "        marker = \"\\n\\n[AUTOMATIC SYSTEM CONTEXT ADDITION] Here is some context that might be useful for answering the question:\"\n",
    "        prompt = prompt + marker + \"\".join(f\"\\n\\n{text}\" for text in retrieved)\n",
    "\n",
    "    return {\"role\": \"user\", \"content\": prompt}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b70162-2c4f-443e-97c8-3e675304d307",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_gpt(message, history):\n",
    "    \"\"\"Stream the model's answer to `message`, yielding the text accumulated so far.\n",
    "\n",
    "    Args:\n",
    "        message: The text the user just typed; retrieved context is added to it\n",
    "            by get_user_prompt before it is sent.\n",
    "        history: Earlier turns as a list of dicts with \"role\" and \"content\".\n",
    "\n",
    "    Yields:\n",
    "        The reply received so far, growing with every streamed delta.\n",
    "    \"\"\"\n",
    "    # Forward only role and content from earlier turns.\n",
    "    # NOTE(review): Gradio history entries may carry extra keys (e.g. metadata)\n",
    "    # depending on the version - confirm; only role/content are needed here.\n",
    "    past_turns = [{\"role\": turn[\"role\"], \"content\": turn[\"content\"]} for turn in history]\n",
    "\n",
    "    messages = [{\"role\": \"system\", \"content\": system_prompt}] + past_turns\n",
    "    messages.append(get_user_prompt(message))\n",
    "    stream = openai.chat.completions.create(\n",
    "        model=MODEL,\n",
    "        messages=messages,\n",
    "        stream=True\n",
    "    )\n",
    "    result = \"\"\n",
    "    for chunk in stream:\n",
    "        # A stream event without choices would make [0] raise IndexError; skip it.\n",
    "        if not chunk.choices:\n",
    "            continue\n",
    "        result += chunk.choices[0].delta.content or \"\"\n",
    "        yield result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ecf4a30-452d-4d41-aa60-fa62c8e2559b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradio\n",
    "# Wrap stream_gpt in a chat interface (history in \"messages\" format) and launch it.\n",
    "\n",
    "chat_ui = gr.ChatInterface(fn=stream_gpt, type=\"messages\")\n",
    "chat_ui.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
